import torch
import numpy as np
from torch import nn, optim, autograd


def cross_entropy(p,q):
    """Return the batch-mean two-class cross-entropy of ``q`` relative to ``p``.

    Only columns 0 and 1 of each argument are read; any further columns are
    ignored.  ``q`` is passed to ``torch.log`` directly, so it is expected to
    hold strictly positive values (presumably probabilities -- confirm with
    the caller).
    """
    # One term per class: -p_k * log(q_k), evaluated row by row.
    first_class_term = -p[:, 0] * torch.log(q[:, 0])
    second_class_term = -p[:, 1] * torch.log(q[:, 1])
    per_row = first_class_term + second_class_term
    return per_row.mean()


def mean_nll(logits, y):
    """Return the mean negative log-likelihood of the targets ``y``.

    ``logits`` is handed to the NLL loss unchanged, and that loss expects
    per-class log-probabilities, so despite the parameter name the caller is
    responsible for applying ``log_softmax`` (or equivalent) beforehand.
    ``y`` is flattened to one dimension and cast to integer class indices.
    """
    class_index = y.view(-1).long()
    return nn.functional.nll_loss(logits, class_index)

def mean_nll2(logits, y):
    """Return the per-sample softmax cross-entropy between ``logits`` and ``y``.

    Despite the ``mean_`` prefix no reduction is applied: the result holds one
    loss value per sample.
    """
    per_sample_loss = nn.functional.cross_entropy(
        input=logits,
        target=y,
        reduction='none',
    )
    return per_sample_loss
def mean_nll2_forIRM(logits, y):
    """Return the element-wise sigmoid binary cross-entropy of ``logits`` vs ``y``.

    Despite the ``mean_`` prefix no reduction is applied: the result has the
    same shape as ``logits``.
    """
    elementwise_loss = nn.functional.binary_cross_entropy_with_logits(
        input=logits,
        target=y,
        reduction='none',
    )
    return elementwise_loss


def mean_accuracy(logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the fraction of rows whose arg-max class equals the target.

    Args:
        logits: Per-class scores; the predicted class is the arg-max along
            dimension 1.
        y: Target class per row, given either as a flat ``(N,)`` tensor or as
            a column ``(N, 1)`` tensor.

    Returns:
        A 0-dim float tensor holding the mean of the per-row hit indicator.
    """
    preds = torch.argmax(logits, dim=1)
    # Compare flat against flat.  Comparing an (N, 1) prediction column with a
    # 1-D target would broadcast to an (N, N) matrix and silently report the
    # wrong accuracy.
    hits = preds.reshape(-1) == y.reshape(-1)
    return hits.float().mean()




def pretty_print(values):
    """Print ``values`` on one line as left-justified, fixed-width columns.

    Strings are printed as given; every other entry is rendered with
    ``np.array2string`` using five fixed decimal places.  Each cell is padded
    to 13 characters and cells are separated by three spaces.
    """
    col_width = 13
    cells = []
    for value in values:
        if isinstance(value, str):
            text = value
        else:
            text = np.array2string(value, precision=5, floatmode='fixed')
        cells.append(text.ljust(col_width))
    print("   ".join(cells))

def prob_sum(x: torch.Tensor) -> torch.Tensor:
    """Collapse three columns into two by merging columns 1 and 2.

    Args:
        x: 2-D tensor with at least three columns.

    Returns:
        An ``(N, 2)`` tensor whose first column is ``x[:, 0]`` and whose
        second column is ``x[:, 1] + x[:, 2]``.
    """
    # Stack the two result columns directly.  The earlier view/cat/transpose
    # chain produced the same values but raised on an empty batch, because
    # ``view(-1, 0)`` cannot infer the missing dimension.
    merged = x[:, 1] + x[:, 2]
    return torch.stack([x[:, 0], merged], dim=1)

def softmax(x):
    """Apply softmax along dimension 1 of ``x``."""
    return nn.functional.softmax(x, dim=1)

def Condi_MI(logits,high_env_number):
    """Return the batch-mean of ``sum_c p_c * (log p_c - log m_c)`` per row.

    Each row of ``logits`` is read as a ``(high_env_number, 2)`` table laid
    out row-major across ``2 * high_env_number`` columns.  ``m`` is the
    product of the table's two marginals: the sum over the
    ``high_env_number`` axis and the sum over the size-2 axis.

    NOTE(review): the values are passed to ``torch.log`` directly, so despite
    the parameter name they are presumably probabilities of a joint
    distribution rather than raw logits -- confirm with the caller.
    """
    n_cells = 2 * high_env_number
    table = logits.view(-1, high_env_number, 2)

    # Marginal over the high_env_number axis, shape (B, 2), tiled so that
    # column c holds the entry for c % 2.
    pair_marginal = table.sum(dim=1).repeat(1, high_env_number)

    # Marginal over the size-2 axis, shape (B, high_env_number), with every
    # entry duplicated so that column c holds the entry for c // 2.
    group_marginal = table.sum(dim=2).repeat_interleave(2, dim=1)

    # Product of the two marginals, aligned cell-by-cell with ``logits``.
    marginal_product = pair_marginal * group_marginal

    # Evaluate the sum of p * log(p / m) over all cells, then average over rows.
    cells = logits[:, :n_cells]
    log_ratio = torch.log(cells) - torch.log(marginal_product)
    return (cells * log_ratio).sum(dim=1).mean(dim=0)

                